// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// Caches information states for each player in all states where they are the current player.
/// Establishes a canonical order for a game's possible information states, indexed by the
/// player ID and an information state index.
public class InformationStateCache<Game: GameProtocol> {
  /// The game whose information states are cached.
  public let game: Game
  /// All real players in `game`, indexed by their player ID.
  public let allPlayers: [Player]
  /// All information states in `game`, indexed by player ID and an information state index.
  public var allInformationStates: [[String]]
  /// For each player, a map from information states to their indices.
  public var informationStateIndices: [[String: Int]]
  /// For each player, a map from information states to a bitvector denoting which actions are legal
  /// in that information state.
  public var legalActionsForInformationState: [[String: [Bool]]]

  /// Creates a cache for `game` with one (initially empty) table per real player, then fills the
  /// tables by calling `populate()`.
  public init(_ game: Game) {
    self.game = game
    allPlayers = (0..<game.playerCount).map { (playerID: Int) -> Player in .player(playerID) }
    allInformationStates = allPlayers.map { _ in [] }
    informationStateIndices = allPlayers.map { _ in [:] }
    legalActionsForInformationState = allPlayers.map { _ in [:] }
    populate()
  }

  /// Visits every state that `forEachState` yields from the game's initial state and records each
  /// previously unseen information state of the player to move.
  ///
  /// Information states receive consecutive indices per player, in the order they are first
  /// encountered. States whose current player is not a real player (for example chance nodes) are
  /// skipped without querying their information state. Information states that are already
  /// recorded are left untouched, so calling this method again adds nothing new.
  public func populate() {
    forEachState(from: game.initialState) { state in
      // Only real players own information states.
      guard case let .player(playerID) = state.currentPlayer else { return }
      let informationState = state.informationStateString(for: state.currentPlayer)
      guard informationStateIndices[playerID][informationState] == nil else { return }
      // The next free index is the current length of the player's list, which keeps the index
      // map and the list aligned by construction.
      informationStateIndices[playerID][informationState] = allInformationStates[playerID].count
      allInformationStates[playerID].append(informationState)
      legalActionsForInformationState[playerID][informationState]
        = game.allActions.map(state.legalActions.contains)
    }
  }
}

/// Calculates state and action values and three kinds of state reach probabilities for a pair of
/// policies, caching intermediate results.
public final class ActionValueCalculator<Game: GameProtocol> {
  /// The canonical ordering of information states used to lay out the returned tensors.
  public let informationStateCache: InformationStateCache<Game>
  /// The game being evaluated (taken from `informationStateCache`).
  public let game: Game
  // NOTE: every table below is allocated with exactly two entries, one per player ID, so this
  // class only supports two-player games.

  /// Probabilities of reaching each information state, conditional on the supplied policies.
  /// Keyed by player ID and information state string.
  public var reachProbabilities: [[String: Double]] = [[:], [:]]
  /// Probabilities of reaching each information state, conditional on the active player always
  /// playing to reach that information state. Keyed by player ID and information state string.
  public var counterfactualReachProbabilities: [[String: Double]] = [[:], [:]]
  /// Probabilities of reaching each information state, conditional on both players always
  /// playing to reach that information state. Keyed by player ID and information state string.
  public var chanceReachProbabilities: [[String: Double]] = [[:], [:]]
  /// Values for each player of every action available in a given information state, keyed by
  /// current player ID and information state string. The stored values are weighted by the reach
  /// probability of the underlying states and are not yet normalized.
  public var actionValues: [[String: [Game.Action: [Double]]]] = [[:], [:]]
  /// Sum of the two best-responder root values computed by the most recent call to `loss(for:)`.
  public var nashConv: Double = 0.0

  /// Creates a calculator for the game held by `informationStateCache`.
  public init(_ informationStateCache: InformationStateCache<Game>) {
    self.informationStateCache = informationStateCache
    self.game = informationStateCache.game
  }

  /// Clears all cached reach probabilities and action values. `nashConv` is left unchanged.
  func reset() {
    reachProbabilities = [[:], [:]]
    counterfactualReachProbabilities = [[:], [:]]
    chanceReachProbabilities = [[:], [:]]
    actionValues = [[:], [:]]
  }

  /// Compute the value of a state for each player, caching action values and reach probabilities
  /// of child states. The `player` argument controls which player is considered "active" for
  /// purposes of counterfactual reach probabilities.
  ///
  /// - Parameters:
  ///   - state: The state to evaluate.
  ///   - player: The active player; its own action probabilities are excluded from the
  ///     counterfactual reach probability.
  ///   - player0Policy: Policy queried when player 0 is to move.
  ///   - player1Policy: Policy queried when any other real player is to move.
  ///   - reachProbability: Probability of reaching `state` under all players and chance.
  ///   - counterfactualReachProbability: Probability of reaching `state` counting only moves not
  ///     made by `player`.
  ///   - chanceReachProbability: Probability of reaching `state` counting only chance moves.
  /// - Returns: A two-element array holding the value of `state` for player 0 and player 1.
  func computeValues<P0: StochasticPolicy, P1: StochasticPolicy>(
    of state: Game.State,
    for player: Player,
    player0Policy: P0,
    player1Policy: P1,
    reachProbability: Double = 1.0,
    counterfactualReachProbability: Double = 1.0,
    chanceReachProbability: Double = 1.0
  ) -> [Double] where P0.Game == Game, P1.Game == Game {
    if state.isTerminal {
      return [0, 1].map { state.utility(for: .player($0)) }
    }
    let currentPlayer = state.currentPlayer
    let transitions: [Game.Action: Double]
    // Set only when a real player is to move; computed once per state instead of once per action.
    var actingPlayer: (id: Int, informationState: String)? = nil
    switch currentPlayer {
    case .chance:
      transitions = state.chanceOutcomes
    case let .player(playerID):
      let infoState = state.informationStateString(for: currentPlayer)
      reachProbabilities[playerID][infoState, default: 0.0] += reachProbability
      counterfactualReachProbabilities[playerID][
        infoState, default: 0.0] += counterfactualReachProbability
      chanceReachProbabilities[playerID][infoState, default: 0.0] += chanceReachProbability
      if currentPlayer == .player(0) {
        transitions = player0Policy.actionProbabilities(forState: state)
      } else {
        transitions = player1Policy.actionProbabilities(forState: state)
      }
      actingPlayer = (id: playerID, informationState: infoState)
    default:
      // Any other kind of current player contributes no transitions and therefore zero value.
      transitions = [:]
    }
    var value = [0.0, 0.0]
    for (action, probability) in transitions {
      let child = state.applying(action)
      let childValues = computeValues(
        of: child,
        for: player,
        player0Policy: player0Policy,
        player1Policy: player1Policy,
        reachProbability: reachProbability * probability,
        // The active player's own choices do not discount the counterfactual reach probability.
        counterfactualReachProbability: (currentPlayer != player ?
          counterfactualReachProbability * probability : counterfactualReachProbability),
        // Only chance moves discount the chance reach probability.
        chanceReachProbability: (currentPlayer == .chance ?
          chanceReachProbability * probability : chanceReachProbability)
      )
      if let acting = actingPlayer {
        // Accumulate reach-weighted child values; they are normalized by the information state's
        // total reach probability when read back in `actionValuesAgainstBestResponder`.
        actionValues[acting.id][acting.informationState, default: [:]][
          action, default: [0.0, 0.0]][0] += childValues[0] * reachProbability
        actionValues[acting.id][acting.informationState, default: [:]][
          action, default: [0.0, 0.0]][1] += childValues[1] * reachProbability
      }
      value[0] += childValues[0] * probability
      value[1] += childValues[1] * probability
    }
    return value
  }

  /// Compute action values and reach probabilities of each information state for a player against
  /// its policy's best responder policy, in addition to the root state's value to the best
  /// responder.
  ///
  /// Calling this method resets and then repopulates the cached reach probabilities and action
  /// values. Information states whose total reach probability is zero get an action value of
  /// `0.0` for every action rather than the NaN that a division by zero would produce.
  ///
  /// Returns:
  ///  - `bestResponderValue`, a Double,
  ///  - `actionValuesAgainstBestResponder`, a tensor of shape
  ///    `allInformationStates.count` x `allActions.count`,
  ///  - `counterfactualReachProbability`, a tensor of shape `allInformationStates.count`.
  ///
  /// If `player` is not a real player, both tensors have zero rows.
  func actionValuesAgainstBestResponder<Policy: StochasticPolicy>(
    player: Player,
    policy: Policy
  ) -> (
    bestResponderValue: Double,
    actionValuesAgainstBestResponder: Tensor<Double>,
    counterfactualReachProbability: Tensor<Double>
  ) where Policy.Game == Game {
    let opponent: Player = player == .player(0) ? .player(1) : .player(0)
    let bestResponder = BestResponse<Game, Policy>(game: game, player: opponent, policy: policy)
    let bestResponderValue = bestResponder.value(game.initialState)
    reset()
    switch player {
    case .player(0):
      _ = computeValues(
        of: game.initialState,
        for: player,
        player0Policy: policy,
        player1Policy: bestResponder)
    case .player(1):
      _ = computeValues(
        of: game.initialState,
        for: player,
        player0Policy: bestResponder,
        player1Policy: policy)
    default:
      break
    }
    let allActions = game.allActions
    let emptyActionValues = Tensor<Double>(zeros: [0, allActions.count])
    let emptyReachProbabilities = Tensor<Double>(zeros: [0])
    guard case let .player(playerID) = player else {
      return (bestResponderValue, emptyActionValues, emptyReachProbabilities)
    }
    let informationStates = informationStateCache.allInformationStates[playerID]
    guard !informationStates.isEmpty else {
      return (bestResponderValue, emptyActionValues, emptyReachProbabilities)
    }

    // Gather all scalars first and build each tensor once, instead of concatenating one row per
    // information state.
    var actionValueScalars: [Double] = []
    actionValueScalars.reserveCapacity(informationStates.count * allActions.count)
    var counterfactualScalars: [Double] = []
    counterfactualScalars.reserveCapacity(informationStates.count)
    for informationState in informationStates {
      let valuesByAction = actionValues[playerID][informationState, default: [:]]
      let reachProbability = reachProbabilities[playerID][informationState, default: 0.0]
      for action in allActions {
        // Unknown actions, and information states that are never reached with positive
        // probability, contribute a value of zero (avoids 0 / 0 = NaN).
        if let value = valuesByAction[action], reachProbability > 0 {
          actionValueScalars.append(value[playerID] / reachProbability)
        } else {
          actionValueScalars.append(0.0)
        }
      }
      counterfactualScalars.append(
        counterfactualReachProbabilities[playerID][informationState, default: 0.0])
    }
    let actionValuesAgainstBestResponder = Tensor<Double>(
      shape: [informationStates.count, allActions.count],
      scalars: actionValueScalars)
    let counterfactualReachProbability = Tensor<Double>(
      shape: [informationStates.count],
      scalars: counterfactualScalars)
    return (bestResponderValue, actionValuesAgainstBestResponder, counterfactualReachProbability)
  }

  /// The exploitability descent loss for a policy.
  ///
  /// Evaluates `policy` as each of the two players against a best responder, stores the sum of
  /// the two best-responder root values in `nashConv`, and returns the negated advantage of the
  /// policy summed over actions and weighted by each information state's counterfactual reach
  /// probability. The action values and reach probabilities are computed from a
  /// `withoutDerivative(at:)` copy of `policy`; only the direct use of `policy.probabilities`
  /// is differentiated.
  @differentiable(wrt: policy)
  public func loss(for policy: TensorFlowTabularPolicy<Game>) -> Double {
    let (nashConv0, qValues0, counterfactualReachProbabilities0) =
      actionValuesAgainstBestResponder(
        player: .player(0),
        policy: withoutDerivative(at: policy))
    let (nashConv1, qValues1, counterfactualReachProbabilities1) =
      actionValuesAgainstBestResponder(
        player: .player(1),
        policy: withoutDerivative(at: policy))
    nashConv = nashConv0 + nashConv1
    // Stack player 0's information states on top of player 1's.
    let qValues = qValues0 ++ qValues1
    let counterfactualReachProbabilities = counterfactualReachProbabilities0 ++
      counterfactualReachProbabilities1
    // NOTE(review): assumes `policy.probabilities[i]` lists player i's information states in the
    // same order as `informationStateCache.allInformationStates[i]` -- confirm.
    let policyValues = policy.probabilities[0] ++ policy.probabilities[1]
    // The baseline is the policy's expected action value per information state; it is computed
    // from a non-differentiated copy of the policy values.
    let baseline = (withoutDerivative(at: policyValues) * qValues).sum(alongAxes: 1)
    let advantage = qValues - baseline
    let lossPerInformationState = -(policyValues * advantage).sum(squeezingAxes: 1)
    return (lossPerInformationState * counterfactualReachProbabilities).sum().scalarized()
  }
}
